import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
from threading import Thread
import time
from PIL import Image
import torch
import spaces
#import subprocess
#subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)


local_model_dir = "E:/LLM/SmolVLM/2B/"
local_tokenizer_dir = "E:/LLM/SmolVLM/2B/"
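# The hard-coded Windows paths above are machine-specific. A hypothetical way
# to make them configurable (SMOLVLM_DIR is an assumed environment variable,
# not defined anywhere else in this script):
# import os
# local_model_dir = local_tokenizer_dir = os.environ.get("SMOLVLM_DIR", "E:/LLM/SmolVLM/2B/")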

processor = AutoProcessor.from_pretrained(local_tokenizer_dir, local_files_only=True)
model = AutoModelForVision2Seq.from_pretrained(
        local_model_dir,
        torch_dtype=torch.bfloat16,
        local_files_only=True,
        #_attn_implementation="flash_attention_2"
        ).to("cuda")
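# Sketch of an optional device fallback, assuming torch.cuda.is_available()
# as the check (illustrative only; the script as written requires a CUDA GPU):
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = AutoModelForVision2Seq.from_pretrained(
#     local_model_dir, torch_dtype=torch.bfloat16, local_files_only=True
# ).to(device)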

@spaces.GPU
def model_inference(
    input_dict, history, decoding_strategy, temperature, max_new_tokens,
    repetition_penalty, top_p
): 
    # Pull the text query out of the multimodal input dict and log the
    # attached file list
    text = input_dict["text"]
    print(input_dict["files"])
    # Open every attached file as an RGB image; an empty file list yields []
    images = [Image.open(image).convert("RGB") for image in input_dict["files"]]
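    # Optional safeguard sketch (not in the original flow): very large uploads
    # could be downscaled in place to bound GPU memory; PIL's thumbnail()
    # resizes while preserving aspect ratio, and (1536, 1536) is an assumed cap:
    # for image in images:
    #     image.thumbnail((1536, 1536))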
    

    # Reject empty queries; gr.Error must be raised for Gradio to display it
    if text == "" and not images:
        raise gr.Error("Please input a query and optionally image(s).")

    # Reject image-only submissions; a text query is required alongside images
    if text == "" and images:
        raise gr.Error("Please input a text query along with the image(s).")

    # Build the user message: one image placeholder per attached image,
    # followed by the text query
    resulting_messages = [
        {
            "role": "user",
            "content": [{"type": "image"} for _ in range(len(images))]
            + [{"type": "text", "text": text}],
        }
    ]
    # Render the chat template to a prompt string, then encode text and images
    # into model inputs (pass None for images on a text-only query)
    prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[images] if images else None, return_tensors="pt")
    # Move every input tensor to the GPU
    inputs = {k: v.to("cuda") for k, v in inputs.items()}
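    # A non-streaming sanity check could look like this (illustrative only;
    # the demo streams partial output instead):
    # output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])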
    # Collect generation parameters
    generation_args = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
    }

    # Only the two strategies exposed in the UI are supported
    assert decoding_strategy in [
        "Greedy",
        "Top P Sampling",
    ]
    # Greedy decoding disables sampling; Top P Sampling enables it with the
    # user-chosen temperature and top_p
    if decoding_strategy == "Greedy":
        generation_args["do_sample"] = False
    elif decoding_strategy == "Top P Sampling":
        generation_args["temperature"] = temperature
        generation_args["do_sample"] = True
        generation_args["top_p"] = top_p

    # Merge the model inputs and a token streamer into the generate() kwargs
    # so the sampling settings above are actually applied
    generation_args.update(inputs)
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_args["streamer"] = streamer

    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()

    yield "..."
    buffer = ""

    # Accumulate streamed tokens and yield the growing buffer so the UI
    # updates incrementally
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer
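    # Joining the worker thread after the stream is drained would guarantee
    # generate() has fully finished before the generator returns; a possible
    # addition, not present in the original flow:
    # thread.join()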


examples = [
    [{"text": "What art era do these art pieces belong to?", "files": ["example_images/rococo.jpg", "example_images/rococo_1.jpg"]}, "Greedy", 0.4, 512, 1.2, 0.8],
    [{"text": "I'm planning a visit to this temple, give me travel tips.", "files": ["example_images/examples_wat_arun.jpg"]}, "Greedy", 0.4, 512, 1.2, 0.8],
    [{"text": "What is the due date and the invoice date?", "files": ["example_images/examples_invoice.png"]}, "Greedy", 0.4, 512, 1.2, 0.8],
    [{"text": "What is this UI about?", "files": ["example_images/s2w_example.png"]}, "Greedy", 0.4, 512, 1.2, 0.8],
    [{"text": "Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}, "Greedy", 0.4, 512, 1.2, 0.8],
]
demo = gr.ChatInterface(
    fn=model_inference,
    title="SmolVLM: Small yet Mighty 💫",
    description="Play with [HuggingFaceTB/SmolVLM-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct) in this demo. To get started, upload an image and text or try one of the examples. This checkpoint works best with single-turn conversations, so clear the conversation after a single turn.",
    examples=examples,
    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
    stop_btn="Stop Generation",
    multimodal=True,
    additional_inputs=[
        gr.Radio(
            ["Top P Sampling", "Greedy"],
            value="Greedy",
            label="Decoding strategy",
            #interactive=True,
            info="Greedy always picks the most likely next token; Top P samples from the nucleus of the distribution.",
        ),
        gr.Slider(
            minimum=0.0,
            maximum=5.0,
            value=0.4,
            step=0.1,
            interactive=True,
            label="Sampling temperature",
            info="Higher values produce more diverse outputs.",
        ),
        gr.Slider(
            minimum=8,
            maximum=1024,
            value=512,
            step=1,
            interactive=True,
            label="Maximum number of new tokens to generate",
        ),
        gr.Slider(
            minimum=0.01,
            maximum=5.0,
            value=1.2,
            step=0.01,
            interactive=True,
            label="Repetition penalty",
            info="1.0 means no penalty.",
        ),
        gr.Slider(
            minimum=0.01,
            maximum=0.99,
            value=0.8,
            step=0.01,
            interactive=True,
            label="Top P",
            info="Higher values sample more low-probability tokens.",
        ),
    ],
    cache_examples=False,
)
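# launch(debug=True) blocks and prints tracebacks to the console. Standard
# Gradio launch flags such as share=True or server_name="0.0.0.0" could be
# added here to expose the demo beyond localhost.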
demo.launch(debug=True)